Figure 4 - Supervised Clustering Adult Census Data


In [1]:
import xgboost
import shap

# load JS visualization code to notebook
shap.initjs() 

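# load the Adult census data and train an XGBoost model
X, y = shap.datasets.adult()
# same rows with human-readable feature values, used only for display
X_display, _ = shap.datasets.adult(display=True)
params = {
    "learning_rate": 0.005,
    "max_depth": 3,
    "objective": "binary:logistic",
    "base_score": y.mean(),  # start from the overall positive rate
}
bst = xgboost.train(params, xgboost.DMatrix(X, label=y), num_boost_round=500)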

# explain the model's predictions using SHAP values
explainer = shap.TreeExplainer(bst)
shap_values = explainer.shap_values(X)


Note that the plot below does not exactly match the figure in the paper, because it uses a different random subset of people than the original. Here we load the data with shap.datasets.adult() for convenience.


In [2]:
# visualize the first 2000 predictions
shap.force_plot(explainer.expected_value, shap_values[:2000,:], X_display.iloc[:2000,:], link="logit")


Out[2]:
[Interactive SHAP force plot of the first 2000 predictions; not rendered in this export.]
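
The `link="logit"` argument in the call above asks the plot to label its output axis in probabilities rather than log-odds. The sketch below illustrates that relationship for a single person. The numbers are made up for illustration and are not taken from the model trained here; `base_value` and `feature_shap` are hypothetical stand-ins for `explainer.expected_value` and one row of `shap_values`.

```python
import math


def logistic(log_odds):
    """Map a log-odds value to a probability (the inverse of the logit)."""
    return 1.0 / (1.0 + math.exp(-log_odds))


# Hypothetical values for one person (assumed, not from the model above).
base_value = -1.2                      # stands in for explainer.expected_value
feature_shap = [0.8, -0.3, 0.5, -0.1]  # stands in for one row of shap_values

# SHAP values are additive in the model's raw (log-odds) output space:
# base value plus the per-feature contributions gives the raw prediction.
log_odds = base_value + sum(feature_shap)

# The logit link relabels that log-odds total on the probability scale.
probability = logistic(log_odds)
base_probability = logistic(base_value)

print(f"log-odds:         {log_odds:.3f}")
print(f"probability:      {probability:.3f}")
print(f"base probability: {base_probability:.3f}")
```

Because the contributions add up in log-odds space, equal SHAP values do not correspond to equal changes in probability; the logit link only changes how the axis is labelled, not the underlying attributions.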
